{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Overlaying heatmaps on images\n",
    "Note: the cells below will only run after you have run `test.py` on the corresponding test sets with the `--visualize` option (e.g. see `scripts/04_eval_visualize_gen_models.sh`), because this notebook pulls the heatmap data from the respective experiment directories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "from utils import imutil, show, renormalize\n",
    "import cv2\n",
    "from scipy.ndimage.filters import gaussian_filter\n",
    "import os\n",
    "from PIL import ImageFilter\n",
    "\n",
    "import torch\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_overlays(path, blur_sigma=1, threshold=0.5, normalize=None):\n",
    "    \"\"\"Show an image and the same image with its heatmap overlaid.\n",
    "\n",
    "    Args:\n",
    "        path: path to an `*orig.png` image. The heatmap is read from the path\n",
    "            obtained by replacing `orig.png` with `heatmap_1.npz` (key `heatmap`).\n",
    "        blur_sigma: forwarded to `imutil.overlay_blur`.\n",
    "        threshold: forwarded to `imutil.overlay_blur`.\n",
    "        normalize: forwarded to `imutil.overlay_blur`. When None, it is set to\n",
    "            True only if the heatmap's max - min exceeds 0.001.\n",
    "    \"\"\"\n",
    "    # Image.open is lazy and keeps the file open; resize() returns a loaded\n",
    "    # copy, so the file handle can be released immediately.\n",
    "    with Image.open(path) as source_image:\n",
    "        image = source_image.resize((128, 128), Image.LANCZOS)\n",
    "\n",
    "    show.a(['original', image], cols=4)\n",
    "\n",
    "    # np.load on an .npz returns an NpzFile that holds an open file handle\n",
    "    # until it is closed.\n",
    "    with np.load(path.replace('orig.png', 'heatmap_1.npz')) as heatmap_file:\n",
    "        heatmap = heatmap_file['heatmap']\n",
    "    image_np = np.array(image)\n",
    "\n",
    "    if normalize is None:\n",
    "        # don't normalize if the heatmap is basically uniform\n",
    "        # to avoid div by zero errors\n",
    "        normalize = bool(np.max(heatmap) - np.min(heatmap) > 0.001)\n",
    "\n",
    "    # 'below' for images under a /fakes/ directory, 'above' otherwise.\n",
    "    # NOTE(review): presumably selects which side of `threshold` is outlined;\n",
    "    # confirm against imutil.overlay_blur.\n",
    "    direction = 'below' if '/fakes/' in path else 'above'\n",
    "    overlay_contour = Image.fromarray(imutil.overlay_blur(\n",
    "        image_np, heatmap, normalize, blur_sigma, False, True, threshold, direction))\n",
    "    show.a(['contour im', overlay_contour], cols=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_heatmaps(prefix, blur_sigma=(2, 2), size=(128, 128)):\n",
    "    \"\"\"Show the blurred average heatmaps for the easiest fakes and reals, and their combination.\n",
    "\n",
    "    Args:\n",
    "        prefix: directory containing `vis/fakes/easiest/heatmap_avg.npz` and\n",
    "            `vis/reals/easiest/heatmap_avg.npz` (each with a `heatmap` key).\n",
    "        blur_sigma: Gaussian sigmas; element 0 is applied to the fakes heatmap,\n",
    "            element 1 to the reals heatmap.\n",
    "        size: size each colorized heatmap image is resized to before being shown.\n",
    "    \"\"\"\n",
    "    def load_blurred(relative_path, sigma):\n",
    "        # Close the .npz archive as soon as the array has been read.\n",
    "        with np.load(os.path.join(prefix, relative_path)) as heatmap_file:\n",
    "            return gaussian_filter(heatmap_file['heatmap'], sigma=sigma)\n",
    "\n",
    "    def show_heatmap(title, heatmap):\n",
    "        colorized = imutil.colorize_heatmap(heatmap, normalize=True)\n",
    "        show.a([title, Image.fromarray(colorized).resize(size)])\n",
    "\n",
    "    easiest_fakes = load_blurred('vis/fakes/easiest/heatmap_avg.npz', blur_sigma[0])\n",
    "    easiest_reals = load_blurred('vis/reals/easiest/heatmap_avg.npz', blur_sigma[1])\n",
    "    show_heatmap('easiest fakes', easiest_fakes)\n",
    "    show_heatmap('easiest reals', easiest_reals)\n",
    "    # Flip the fakes heatmap (1 - x) before averaging it with the reals heatmap.\n",
    "    easiest = (easiest_reals + 1 - easiest_fakes) / 2\n",
    "    show_heatmap('easiest avg', easiest)\n",
    "    show.flush()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overlaying heatmaps on visualized images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pgan pretrained\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celebahq-pgan-pretrained/'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/010_orig.png',\n",
    "    'vis/fakes/easiest/026_orig.png',\n",
    "    'vis/reals/easiest/099_orig.png',\n",
    "    'vis/reals/easiest/096_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path)\n",
    "draw_heatmaps(prefix)\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# celebahq stylegan pretrained\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/celebahq-sgan-pretrained/'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/068_orig.png',\n",
    "    'vis/fakes/easiest/048_orig.png',\n",
    "    'vis/reals/easiest/019_orig.png',\n",
    "    'vis/reals/easiest/023_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path)\n",
    "draw_heatmaps(prefix, blur_sigma=(2, 2))\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# glow model\n",
    "prefix = '../results/gp1d-gan-samplesonly_seed0_xception_block1_constant_p50/test/epoch_bestval/celebahq-glow-pretrained/'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/002_orig.png',\n",
    "    'vis/fakes/easiest/007_orig.png',\n",
    "    'vis/reals/easiest/002_orig.png',\n",
    "    'vis/reals/easiest/009_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path, blur_sigma=2)\n",
    "draw_heatmaps(prefix)\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gmm model\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/celeba-gmm'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/031_orig.png',\n",
    "    'vis/fakes/easiest/022_orig.png',\n",
    "    'vis/reals/easiest/006_orig.png',\n",
    "    'vis/reals/easiest/009_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path)\n",
    "draw_heatmaps(prefix)\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ffhq pgan 9k\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block2_constant_p20/test/epoch_bestval/ffhq-pgan/'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/022_orig.png',\n",
    "    'vis/fakes/easiest/019_orig.png',\n",
    "    'vis/reals/easiest/003_orig.png',\n",
    "    'vis/reals/easiest/030_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path)\n",
    "draw_heatmaps(prefix)\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ffhq sgan2 pretrained\n",
    "prefix = '../results/gp1-gan-winversion_seed0_xception_block3_constant_p10/test/epoch_bestval/ffhq-sgan2'\n",
    "# overlays for two fake and two real examples from the 'easiest' folders\n",
    "for path in [os.path.join(prefix, rel) for rel in (\n",
    "    'vis/fakes/easiest/047_orig.png',\n",
    "    'vis/fakes/easiest/035_orig.png',\n",
    "    'vis/reals/easiest/001_orig.png',\n",
    "    'vis/reals/easiest/055_orig.png',\n",
    ")]:\n",
    "    draw_overlays(path)\n",
    "draw_heatmaps(prefix, blur_sigma=(2, 0))\n",
    "show.flush()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FaceForensics\n",
    "You can run a similar experiment on the FaceForensics dataset, but you will first have to preprocess the frames according to `scripts/00_data_processing_faceforensics_aligned_frames.sh` and then run the evaluation script, following `scripts/04_eval_visualize_faceforensics_F2F.sh` for example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# F2F test on F2F (left commented out: it needs the FaceForensics evaluation outputs described above)\n",
    "# prefix = '../results/gp5-faceforensics-f2f_baseline_resnet18_layer1/test/epoch_bestval/F2F/'\n",
    "# path = os.path.join(prefix, 'vis/reals/easiest/001_orig.png')\n",
    "# draw_overlays(path, blur_sigma=2)\n",
    "# path = os.path.join(prefix, 'vis/reals/easiest/012_orig.png')\n",
    "# draw_overlays(path,  blur_sigma=2)\n",
    "# path = os.path.join(prefix, 'vis/fakes/easiest/001_orig.png')\n",
    "# draw_overlays(path,  blur_sigma=2)\n",
    "# path = os.path.join(prefix, 'vis/fakes/easiest/082_orig.png')\n",
    "# draw_overlays(path,  blur_sigma=2)\n",
    "# draw_heatmaps(prefix)\n",
    "# show.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "forgery",
   "language": "python",
   "name": "forgery"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}